#!/usr/bin/env python
#-*- encoding:utf-8 -*-

import errno
import logging
import signal
import socket
import sys

import gevent
from gevent import monkey
#import color_log

# Patch the standard library (socket, etc.) for gevent before any socket
# object is created below.
monkey.patch_all()

# Shared sockets; all three are assigned in main().
serv = udp_serv = udp_remote = None

SERVER_PORT = 1099      # TCP port the SOCKS5 listener binds on 0.0.0.0
SOCKET_TIMEOUT = 300    # seconds; applied to accepted and remote TCP sockets

# Live statistics, printed periodically by print_info().
Sessions = UpBytes = DownBytes = 0

def domain_to_addr(host, port):
    """Resolve *host*:*port* to an IPv4 UDP socket address.

    Args:
        host: domain name or dotted-quad string. udp_proxy passes values
            taken directly from client datagrams, so it may be malformed.
        port: destination port number.

    Returns:
        The first ``(ip, port)`` sockaddr reported by getaddrinfo, or None
        when resolution fails or yields no results.
    """
    try:
        addrs = socket.getaddrinfo(host, port, socket.AF_INET,
                                   socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    except Exception:
        # Deliberately not a bare except, so KeyboardInterrupt, SystemExit
        # and gevent's GreenletExit still propagate. Everything else
        # (gaierror, over-long IDNA labels, embedded NULs, ...) just means
        # "unresolvable" here.
        return None
    if not addrs:
        return None
    return addrs[0][4]


class UDPSessionMgr:
    """Map remote UDP endpoints back to the local client that addressed them.

    Entries are keyed by the remote ``(ip, port)``. A later update() for the
    same remote endpoint replaces the earlier entry, so only one local
    client is remembered per remote endpoint.

    NOTE(review): entries are never removed, so the table grows with every
    distinct destination for the life of the process.
    """

    def __init__(self):
        # (remote_ip, remote_port) -> (local_addr, remote_addr, addr_data).
        # A dict gives O(1) lookups per relayed datagram; a list needs a scan.
        self.sessions_ = {}

    def count(self):
        """Return the number of remote endpoints currently tracked."""
        return len(self.sessions_)

    def update(self, local_addr, remote_addr, addr_data):
        """Record that *local_addr* sent to *remote_addr*.

        Args:
            local_addr: the client's UDP address.
            remote_addr: resolved destination sockaddr; only [0] and [1]
                form the key.
            addr_data: opaque header bytes stored alongside the entry.
        """
        key = (remote_addr[0], remote_addr[1])
        self.sessions_[key] = (local_addr, remote_addr, addr_data)

    def findLocalAddr(self, remote_addr):
        """Return the local address last recorded for *remote_addr*, or None."""
        entry = self.sessions_.get((remote_addr[0], remote_addr[1]))
        if entry is None:
            return None
        return entry[0]

udp_sessions = UDPSessionMgr()


class PROXY_STATUS:
    """States of one client connection in proxy_proc.

    ssInit:    waiting for the SOCKS5 greeting.
    ssLogin:   greeting answered; waiting for the request (CONNECT/UDP).
    ssPairOk:  remote connected; relaying client data upstream.
    ssFailure: request failed.
    """
    ssInit, ssLogin, ssPairOk, ssFailure = range(4)


def BytesToStr(nBytes):
    """Render a byte count as a short human-readable string.

    Uses binary units with two decimals ("1.50K", "3.00G"). Counts below
    1024 are printed as an integer followed by " Bytes".
    """
    # Checked from largest to smallest so the first match is the best unit.
    for suffix, unit in (("G", 1 << 30), ("M", 1 << 20), ("K", 1 << 10)):
        if nBytes >= unit:
            return "%.2f%s" % (float(nBytes) / unit, suffix)
    return "%d Bytes" % nBytes


def init_logger():
    """Configure root logging for the proxy.

    Sets the root level to DEBUG and renames the level labels to 4
    characters. Calls basicConfig, which adds a console handler only if the
    root logger has none yet. Always attaches a DEBUG FileHandler writing
    to "test.log" in the current directory.
    """
    fmt = '%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)s-%(funcName)s] %(message)s'
    root = logging.getLogger()

    logging.basicConfig(format=fmt)
    root.setLevel(logging.DEBUG)
    level_names = (
        (logging.DEBUG, "DEBG"),
        (logging.INFO, 'INFO'),
        (logging.WARNING, "WARN"),
        (logging.ERROR, "ERRR"),
    )
    for level, label in level_names:
        logging.addLevelName(level, label)

    file_handler = logging.FileHandler("test.log")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(fmt))
    root.addHandler(file_handler)

def udp_remote_proxy():
    """Relay datagrams from target hosts back to SOCKS5 UDP clients.

    Reads datagrams on ``udp_remote`` and looks up the client that last sent
    to the datagram's source in ``udp_sessions``. The payload is prefixed
    with a SOCKS5 UDP header (RFC 1928 section 7) that names that source,
    then sent to the client from ``udp_serv``. Datagrams from unknown
    sources are dropped. The loop ends only when ``udp_remote`` raises on
    receive; the socket is then closed.
    """
    global udp_remote
    global udp_serv
    global udp_sessions

    while True:
        try:
            data, remote_addr = udp_remote.recvfrom(8192)
        except Exception as err:
            logging.warning("udp remote exception %s", str(err))
            udp_remote.close()
            break
        # A zero-length datagram is legal UDP and says nothing about the
        # socket's health, so it is relayed like any other datagram.

        local_addr = udp_sessions.findLocalAddr(remote_addr)
        if local_addr is None:
            logging.warning("remote recv but not session")
            continue

        logging.info("udp recv %s => %s", str(remote_addr), str(local_addr))
        # Header: RSV(2)=0, FRAG=0, ATYP=1 (IPv4), then the *sender's*
        # address and port. recvfrom on this AF_INET socket returns only
        # (ip, port), so there is no third element to index.
        remote_ip, remote_port = remote_addr[0], remote_addr[1]
        pack_hdr = ('\x00\x00\x00\x01' + socket.inet_aton(remote_ip)
                    + chr(remote_port >> 8) + chr(remote_port & 0xff))
        try:
            udp_serv.sendto(pack_hdr + data, local_addr)
        except socket.error as err:
            # One unreachable client must not stop the relay for everyone.
            logging.warning("udp send to %s failure %s", str(local_addr), str(err))


def _parse_udp_request(data):
    """Parse the SOCKS5 UDP request header at the start of *data*.

    Layout (RFC 1928 section 7): RSV(2) FRAG(1) ATYP(1) DST.ADDR DST.PORT(2)
    DATA. Only ATYP 0x01 (IPv4) and 0x03 (length-prefixed domain name) are
    supported.

    Returns:
        ``(to_addr, to_port, data_offset, addr_offset, addr_len)``, where
        ``data[data_offset:]`` is the payload and
        ``data[addr_offset:addr_offset + addr_len]`` is the ATYP..DST.PORT
        span. Returns None when the datagram is too short for the address it
        announces, is a fragment, or uses an unsupported ATYP.
    """
    if len(data) < 10:
        return None
    if data[2] != '\x00':
        # Fragment reassembly is not implemented; RFC 1928 says such
        # datagrams MUST be dropped.
        return None
    atype = data[3]
    if atype == '\x01':
        to_addr = "%d.%d.%d.%d" % (ord(data[4]), ord(data[5]), ord(data[6]), ord(data[7]))
        port_offset = 8
    elif atype == '\x03':
        port_offset = 5 + ord(data[4])
        # The length byte is attacker-controlled; make sure the name and
        # the 2-byte port really fit before indexing them.
        if len(data) < port_offset + 2:
            return None
        to_addr = data[5:port_offset]
    else:
        return None
    to_port = (ord(data[port_offset]) << 8) + ord(data[port_offset + 1])
    data_offset = port_offset + 2
    return to_addr, to_port, data_offset, 3, data_offset - 3


def udp_proxy():
    """Forward SOCKS5 UDP datagrams from clients to their destinations.

    Reads datagrams on ``udp_serv``, strips the SOCKS5 UDP header, resolves
    the destination, records the client in ``udp_sessions`` so replies can
    be routed back, and sends the payload from ``udp_remote``. Malformed
    datagrams are dropped. The loop ends only when ``udp_serv`` raises on
    receive; the socket is then closed.

    NOTE(review): datagrams are accepted from any sender, not only from
    clients that performed UDP ASSOCIATE over TCP, so this is an open relay.
    NOTE(review): domain_to_addr resolves inline, so a slow lookup delays
    every datagram queued behind it.
    """
    global udp_serv
    global udp_remote
    global udp_sessions

    while True:
        try:
            data, from_addr = udp_serv.recvfrom(8192)
        except Exception as err:
            logging.warning("udp proxy exception " + str(err))
            udp_serv.close()
            break

        # Zero-length and other short datagrams are legal UDP. They are
        # rejected as malformed below instead of shutting the relay down.
        parsed = _parse_udp_request(data)
        if parsed is None:
            logging.warning("warnning udp protocol %s", str(from_addr))
            continue
        to_addr, to_port, data_offset, addr_offset, addr_len = parsed

        logging.info("udp send data %s => %s:%d", str(from_addr), to_addr, to_port)

        dest_addr = domain_to_addr(to_addr, to_port)
        if dest_addr is None:
            logging.warning("domain to ip failure %s:%d", to_addr, to_port)
            continue

        logging.info("udp send data %s => %s", str(from_addr), str(dest_addr))
        udp_sessions.update(from_addr, dest_addr, data[addr_offset:addr_offset + addr_len])
        try:
            udp_remote.sendto(data[data_offset:], dest_addr)
        except socket.error as err:
            # E.g. EACCES for broadcast targets; drop this datagram only.
            logging.warning("udp send to %s failure %s", str(dest_addr), str(err))

    print("break udp proxy")


def _shutdown_and_close(sock):
    """Shut down *sock* for writing, then close it, ignoring socket errors.

    shutdown() raises on a socket that is already closed or not connected.
    It gets its own try so such a failure can never skip close() and leak
    the descriptor.
    """
    try:
        sock.shutdown(socket.SHUT_WR)
    except socket.error:
        pass
    try:
        sock.close()
    except socket.error:
        pass


def remote_proc(remote, client, remote_addr, client_addr):
    """Copy bytes from *remote* to *client* until either side stops.

    proxy_proc spawns this as the remote->client half of a CONNECT relay.
    Every forwarded byte is added to the global DownBytes. When the remote
    closes, or any error occurs on either socket, both sockets are shut down
    and closed.

    Args:
        remote: connected socket to the target host.
        client: the SOCKS client's socket.
        remote_addr: description of the target, used only for logging.
        client_addr: description of the client, used only for logging.
    """
    global DownBytes
    try:
        while True:
            data = remote.recv(4096)
            if not data:
                logging.info("remote disconnect %s, %s", remote_addr, client_addr)
                break
            # send() may write only part of the buffer; sendall() loops
            # until every byte is written.
            client.sendall(data)
            DownBytes += len(data)
    except Exception as err:
        # Idle timeouts and resets are routine for a proxy, so log at debug.
        logging.debug("remote relay ended %s, %s: %s", remote_addr, client_addr, str(err))
    finally:
        # Also runs on greenlet kill, so sockets are never left open.
        _shutdown_and_close(client)
        _shutdown_and_close(remote)


def _parse_request_target(data):
    """Extract the destination from a SOCKS5 request (RFC 1928 section 4).

    Layout: VER CMD RSV ATYP DST.ADDR DST.PORT(2). Only ATYP 0x01 (IPv4)
    and 0x03 (length-prefixed domain name) are supported.

    Returns:
        ``(host, port)``, where host is a dotted quad or the raw domain
        string. Returns None when ATYP is unsupported, *data* is too short
        for the address it announces, the host is empty, or the port is 0.
    """
    if len(data) < 5:
        return None
    atype = data[3]
    if atype == '\x01':
        port_offset = 8
        if len(data) < port_offset + 2:
            return None
        host = "%d.%d.%d.%d" % (ord(data[4]), ord(data[5]), ord(data[6]), ord(data[7]))
    elif atype == '\x03':
        port_offset = 5 + ord(data[4])
        if len(data) < port_offset + 2:
            return None
        host = data[5:port_offset]
    else:
        return None
    port = (ord(data[port_offset]) << 8) + ord(data[port_offset + 1])
    if host == "" or port == 0:
        return None
    return host, port


def proxy_proc(client, addr):
    """Serve one accepted SOCKS5 client connection until it ends.

    Each client.recv() chunk is handled according to the PROXY_STATUS state:

    * ssInit: the greeting. Answers "no authentication" (0x05 0x00).
    * ssLogin: the request.
      - CONNECT (0x01) opens a TCP connection to the target, replies, and
        spawns remote_proc for the remote->client direction.
      - UDP ASSOCIATE (0x03) replies with this host's address on the
        client connection and the port of ``udp_serv``.
    * ssPairOk: relays client->remote bytes and adds them to UpBytes.

    Whatever ends the session (EOF, protocol error, socket error, greenlet
    kill), both sockets are closed and the Sessions counter is decremented.

    NOTE(review): each recv() is assumed to hold exactly one complete SOCKS
    message. TCP does not guarantee that, so a split or coalesced handshake
    is misparsed.

    Args:
        client: accepted client socket.
        addr: the client's peer address; used only for logging.
    """
    global SOCKET_TIMEOUT
    global UpBytes
    global Sessions
    global udp_serv
    stat = PROXY_STATUS.ssInit  # proxy state
    local_addr = addr
    remote = None  # socket to the target, once a CONNECT is attempted

    Sessions += 1
    try:
        while True:
            try:
                data = client.recv(4096)
            except socket.error:
                logging.info("client disconnected %s", str(local_addr))
                break
            if not data:
                logging.info("client disconnected %s", str(local_addr))
                break

            if stat == PROXY_STATUS.ssInit:
                if data[0] != '\x05':
                    logging.warning("unkown command cmd = %d, %s", ord(data[0]), str(local_addr))
                    break
                # NOTE(review): "no authentication" is chosen without checking
                # that the client actually offered it in METHODS.
                client.sendall('\x05\x00')
                stat = PROXY_STATUS.ssLogin

            elif stat == PROXY_STATUS.ssLogin:
                cmd = data[1:2]  # '' when the chunk is a single byte
                if data[0] != '\x05' or cmd not in ('\x01', '\x03'):
                    logging.warning("%s unkown protocol cmd = %r", str(local_addr), data[:2])
                    break

                if cmd == '\x03':  # UDP ASSOCIATE
                    udp_port = udp_serv.getsockname()[1]
                    udp_ip = client.getsockname()[0]
                    reply = ('\x05\x00\x00\x01' + socket.inet_aton(udp_ip)
                             + chr(udp_port >> 8) + chr(udp_port & 0xff))
                    # No idle timeout: RFC 1928 ties the association's
                    # lifetime to this TCP connection.
                    client.settimeout(None)
                    client.sendall(reply)
                    continue

                # CONNECT
                target = _parse_request_target(data)
                if target is None:
                    logging.warning("%s bad request address %r", str(local_addr), data[3:5])
                    # General failure, with a well-formed 0.0.0.0:0 bound
                    # address (the request itself may be truncated).
                    client.sendall('\x05\x01\x00\x01\x00\x00\x00\x00\x00\x00')
                    stat = PROXY_STATUS.ssFailure
                    break

                remote_addr, remote_port = target
                logging.info("%s connect to %s:%d", str(local_addr), remote_addr, remote_port)
                remote = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                # Set before connect() so an unresponsive target cannot block
                # this greenlet indefinitely.
                remote.settimeout(SOCKET_TIMEOUT)
                try:
                    remote.connect((remote_addr, remote_port))
                except Exception:
                    # Untrusted target: DNS failure, refusal, timeout or a
                    # malformed host name all end up here.
                    logging.info("%s remote connect failure %s:%d", str(local_addr), remote_addr, remote_port)
                    client.sendall(data[:1] + '\x05' + data[2:])
                    stat = PROXY_STATUS.ssFailure
                    break

                logging.info("%s remote connect success %s:%d", str(local_addr), remote_addr, remote_port)
                client.sendall(data[:1] + '\x00' + data[2:])
                gevent.spawn(remote_proc, remote, client, "(%s:%d)" % (remote_addr, remote_port), str(local_addr))
                stat = PROXY_STATUS.ssPairOk

            elif stat == PROXY_STATUS.ssPairOk:
                remote.sendall(data)
                UpBytes += len(data)

            else:
                logging.warning("%s WTF statue failure", str(local_addr))
                break
    except Exception as err:
        # Malformed input or a socket error while replying or relaying.
        logging.warning("%s proxy error: %s", str(local_addr), str(err))
    finally:
        for sock in (client, remote):
            if sock is None:
                continue
            try:
                sock.close()
            except socket.error:
                pass
        Sessions -= 1


# accept() errors that concern a single pending connection, or a temporary
# resource shortage, rather than the listening socket itself.
_ACCEPT_RETRY_ERRNOS = (
    errno.EINTR,
    errno.EAGAIN,
    errno.ECONNABORTED,
    errno.EMFILE,
    errno.ENFILE,
    errno.ENOBUFS,
    errno.ENOMEM,
)


def accept_proc(sckt):
    """Accept clients on *sckt* and spawn proxy_proc for each of them.

    Each accepted socket gets SOCKET_TIMEOUT. Transient accept() failures
    (see _ACCEPT_RETRY_ERRNOS) are logged and retried after a short pause.
    Any other socket error, such as EBADF once sigint_handler has closed
    the listener, ends the loop and returns.
    """
    global SOCKET_TIMEOUT
    while True:
        try:
            client, addr = sckt.accept()
        except socket.error as err:
            if getattr(err, "errno", None) in _ACCEPT_RETRY_ERRNOS:
                logging.warning("accept failure %s, retrying", str(err))
                # Back off so EMFILE/ENFILE does not become a busy loop.
                gevent.sleep(0.1)
                continue
            break

        client.settimeout(SOCKET_TIMEOUT)
        logging.info("client accept addr = %s", addr)
        gevent.spawn(proxy_proc, client, addr)


def sigint_handler(signum, frame):
    """SIGINT handler: close the TCP listener so the accept loop ends.

    Closing ``serv`` makes accept() in accept_proc fail, which ends that
    greenlet and lets main() return.

    Args:
        signum: signal number (unused).
        frame: interrupted stack frame (unused).
    """
    global serv
    print("")
    logging.warning("catch interrupte to eixt...")
    # main() installs this handler before it creates the listener, so serv
    # can still be None if the interrupt arrives early.
    if serv is not None:
        serv.close()

def print_info():
    """Forever print the session count and the up/down byte totals, every 5 seconds."""
    global Sessions
    global UpBytes
    global DownBytes
    while True:
        line = "Session = %d, UpBytes = %s, DownBytes = %s" % (
            Sessions, BytesToStr(UpBytes), BytesToStr(DownBytes))
        print(line)
        gevent.sleep(5)

def main(argv):
    """Run the SOCKS5 proxy until the TCP listener is closed.

    Steps:
    * Configure logging and install sigint_handler.
    * Listen on 0.0.0.0:SERVER_PORT.
    * Create the two UDP relay sockets on ephemeral ports.
    * Spawn the worker greenlets.
    * Wait for accept_proc to return, then close every socket.

    Args:
        argv: command-line arguments (currently unused).
    """
    global serv
    global SERVER_PORT
    global udp_serv
    global udp_remote

    init_logger()
    logging.info("=========== start =============")

    signal.signal(signal.SIGINT, sigint_handler)

    serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    try:
        serv.bind(("0.0.0.0", SERVER_PORT))
        serv.listen(5)
    except Exception as err:
        logging.error("listen failure! %s" % (str(err)))
        serv.close()
        return

    logging.info("listen on %d", SERVER_PORT)

    thd = gevent.spawn(accept_proc, serv)
    gevent.spawn(print_info)

    # UDP ASSOCIATE relay: udp_serv exchanges datagrams with SOCKS clients,
    # udp_remote with destination hosts.
    udp_serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    udp_serv.bind(("", 0))

    udp_remote = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
    udp_remote.bind(("", 0))
    logging.info("udp relay on port %d", udp_serv.getsockname()[1])

    gevent.spawn(udp_proxy)
    gevent.spawn(udp_remote_proxy)

    thd.join()

    serv.close()
    udp_serv.close()
    udp_remote.close()

    logging.info("END!")

if __name__ == "__main__":
    main(sys.argv)
